Fix CUDA graph parameter grad lifetime #2937

Open

buptzyb wants to merge 3 commits into NVIDIA:main from buptzyb:fix/cudagraph-wgrad-lifetime

Conversation

@buptzyb
Contributor

@buptzyb buptzyb commented Apr 28, 2026

Summary

Fix CUDA graph replay so parameter gradients returned from Graphed.backward do not expose CUDA graph static buffers to downstream autograd users.

The fix clones returned parameter gradients before handing them back to autograd, while preserving the existing aliasing behavior for delayed-wgrad parameters marked with skip_backward_post_hook.

Root Cause

When CUDA graph replay returns parameter grad slots directly from static graph buffers, downstream autograd users can retain references to buffers that are overwritten by later graph replays. This can corrupt retained grads or break gradient accumulation semantics.

This is related to PyTorch issue pytorch/pytorch#181723.
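
For concreteness, here is a minimal hypothetical repro of the hazard, sketched against PyTorch's upstream torch.cuda.make_graphed_callables rather than this PR's code; whether .grad actually ends up aliasing the static buffer depends on autograd's accumulation path, so the final check is hedged:

# Hypothetical repro sketch (not from this PR): a grad retained from a
# graphed backward may alias the graph's static buffer, so a later replay
# can rewrite it in place.
import torch

model = torch.nn.Linear(16, 16).cuda()
sample = torch.randn(8, 16, device="cuda", requires_grad=True)
graphed = torch.cuda.make_graphed_callables(model, (sample,))

x1 = torch.randn(8, 16, device="cuda", requires_grad=True)
graphed(x1).sum().backward()
retained = model.weight.grad        # may be a view of a static graph buffer
snapshot = retained.clone()         # owned copy for comparison

model.weight.grad = None            # force a fresh grad on the next step
x2 = torch.randn(8, 16, device="cuda", requires_grad=True)
graphed(x2).sum().backward()        # replay rewrites the static buffers

# False here means `retained` was silently overwritten by the replay.
print(torch.equal(retained, snapshot))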

Changes

  • Detect parameter grad slots in the graphed autograd input surface.
  • Clone returned non-delayed-wgrad parameter grads before returning from Graphed.backward (see the sketch after this list).
  • Allow reused graph input/output buffer mode to weak-ref current parameter grad static buffers after capture because returned grads are now cloned.
  • Add CUDA graph tests for owned returned parameter grads, accumulation, delayed-wgrad alias preservation, and reused buffer interleaved pipeline replay.
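
A minimal sketch of that clone-on-return mechanism, assuming the shape suggested by the diff quoted in the review thread below; _returned_param_grad_clone_slots follows the PR's naming, while _grads_to_return is an illustrative stand-in for the logic inside Graphed.backward, not the merged code:

from typing import Sequence, Tuple
import torch

def _returned_param_grad_clone_slots(
    static_grad_inputs: Sequence[torch.Tensor],
    module_params: Sequence[torch.nn.Parameter],
) -> Tuple[bool, ...]:
    """Capture-time snapshot: which static grad slots to clone on return.

    Parameter grads occupy the tail of static_grad_inputs; slots whose
    parameter is marked skip_backward_post_hook (delayed wgrad) keep the
    existing aliasing behavior and are never cloned.
    """
    param_start = len(static_grad_inputs) - len(module_params)
    return tuple(
        idx >= param_start
        and not getattr(
            module_params[idx - param_start], "skip_backward_post_hook", False
        )
        for idx in range(len(static_grad_inputs))
    )

def _grads_to_return(static_grad_inputs, clone_slots):
    """Replay time: hand autograd owned copies instead of static buffer views."""
    return tuple(
        None if grad is None
        else grad.detach().clone() if do_clone
        else grad.detach()
        for grad, do_clone in zip(static_grad_inputs, clone_slots)
    )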

Signed-off-by: Robin Zhang <robinz@nvidia.com>
@buptzyb buptzyb force-pushed the fix/cudagraph-wgrad-lifetime branch from 4cc8b89 to 6e16c63 on April 28, 2026 at 08:18
@buptzyb buptzyb marked this pull request as ready for review April 28, 2026 08:20
@greptile-apps
Contributor

greptile-apps Bot commented Apr 28, 2026

Greptile Summary

This PR fixes a lifetime bug in CUDA graph replay where parameter gradients returned from Graphed.backward were direct views of static graph buffers, meaning a later graph replay could silently overwrite retained .grad tensors or break gradient accumulation. The fix clones returned parameter gradients at replay time, with a snapshot of the per-parameter clone policy captured at graph-capture time.

  • Core fix in graph.py: _returned_param_grad_clone_slots() computes a per-slot boolean tuple at capture time; Graphed.backward uses it to .detach().clone() parameter grad slots before returning.
  • _reuse_graph_input_output_buffers integration: Parameter grad static buffers can now be safely weak-refed at capture time for the reused-buffer pipeline path (illustrated after this list).
  • New tests cover owned-grad isolation, accumulation correctness, delayed-wgrad alias preservation, opt-out via _clone_param_grads_on_return=False, policy snapshotting, and interleaved pipeline runs with weight equality checks.
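
A small illustration of why the weak-ref becomes safe, assuming only the invariant described above (returned grads are owned clones); this is not the PR's code:

# Once Graphed.backward returns an owned clone, the reused-buffer path can
# hold a mere weak reference to the static grad buffer without risking a
# dangling view escaping to autograd. (Illustrative, CPU-only sketch.)
import weakref
import torch

static_grad = torch.zeros(4)             # stand-in for a static grad buffer
returned = static_grad.detach().clone()  # what backward now hands out

ref = weakref.ref(static_grad)           # capture-time weak reference
del static_grad                          # buffer released for reuse
assert ref() is None                     # the weak ref kept nothing alive
assert torch.equal(returned, torch.zeros(4))  # the returned clone is intact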

Confidence Score: 5/5

Safe to merge: the fix correctly isolates CUDA graph static buffers from downstream autograd users without breaking existing aliasing semantics for delayed-wgrad parameters.

The slot-offset arithmetic matches how static_input_surface is constructed, the policy snapshot is captured exactly once at graph-capture time, stream ordering around bwd_graph.replay() ensures buffers are fully written before cloning, and the tests cover all meaningful behavioral combinations, confirmed passing on H100.

No files require special attention.

Important Files Changed

Filename: transformer_engine/pytorch/graph.py
Overview: Adds the _clone_param_grads_on_return flag and the _returned_param_grad_clone_slots() snapshot helper; clones parameter grad slots in Graphed.backward; correctly weak-refs cloned slots in the reused-buffer path.

Filename: tests/pytorch/test_cuda_graphs.py
Overview: Adds five targeted unit tests for the new clone behavior and updates the interleaved pipeline helper to return final weights; all existing and new tests pass on H100.

Reviews (5). Last reviewed commit: "Add CUDA graph param grad clone opt-out"

Comment thread: transformer_engine/pytorch/graph.py (Outdated)
Comment on lines +410 to +417
def _is_returned_param_grad_slot(idx, static_grad_inputs, module_params):
    """Return whether a static grad slot is consumed through Graphed.backward."""
    # Parameter grad slots occupy the tail of static_grad_inputs, after the
    # data-input grad slots.
    module_param_start = len(static_grad_inputs) - len(module_params)
    if idx < module_param_start:
        return False
    # Delayed-wgrad params (skip_backward_post_hook=True) keep aliasing the
    # static buffers and are not returned through Graphed.backward.
    return not getattr(
        module_params[idx - module_param_start], "skip_backward_post_hook", False
    )
Contributor

P2 Timing inconsistency between capture and replay attribute reads

_is_returned_param_grad_slot reads skip_backward_post_hook live at both capture time (line 748) and replay time (line 945). If a caller flips the attribute between those two points, the weak-ref decision at capture and the clone decision at replay get out of sync.

Specifically, if the attribute was False at capture (→ static buffer was weak-refed in the _reuse_graph_input_output_buffers path) but True at replay (→ code calls .detach() instead of .detach().clone()), the returned tensor is a detached view of an already-released weak-ref buffer whose memory may have been reused. Snapshotting the skip_backward_post_hook state once at capture time and storing it alongside the static grad slot (or asserting it is unchanged at replay) would make the contract explicit.
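
A hedged sketch of the two remedies suggested here (names hypothetical, not the PR's code):

# Option A: snapshot the policy once at capture time and reuse it at replay.
# Option B: assert at replay that nobody flipped the attribute since capture.
import torch

def snapshot_clone_policy(module_params):
    """Capture-time: record which params get cloned grads on return."""
    return tuple(
        not getattr(p, "skip_backward_post_hook", False) for p in module_params
    )

def check_policy_unchanged(module_params, captured):
    """Replay-time: fail loudly if the contract changed after capture."""
    if snapshot_clone_policy(module_params) != captured:
        raise RuntimeError(
            "skip_backward_post_hook changed between graph capture and replay"
        )

params = list(torch.nn.Linear(4, 4).parameters())
policy = snapshot_clone_policy(params)   # once, at capture
check_policy_unchanged(params, policy)   # at every replay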

Contributor Author

Addressed in 4077b85. The parameter-grad clone policy is now snapshotted at capture time and passed into Graphed.backward, so replay no longer re-reads skip_backward_post_hook. Added test_make_graphed_callables_snapshots_parameter_grad_clone_policy to cover changing the attribute after capture.

Comment thread: tests/pytorch/test_cuda_graphs.py
Comment on lines +889 to +906
def test_make_graphed_callables_with_interleaved_pipeline_parallelism_reused_buffers(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float16,
) -> None:
    """Test CUDA graphs with reused input/output buffers."""
    model_config = model_configs[model_config]
    kwargs = dict(model_config=model_config, dtype=dtype)
    # Eager baseline: same interleaved pipeline schedule without CUDA graphs.
    outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=False,
        **kwargs,
    )
    # Graphed run with the reused input/output buffer mode enabled.
    graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=True,
        reuse_graph_input_output_buffers=True,
        **kwargs,
    )
    assert_all_equal(outputs, graph_outputs)
Contributor

P2 Reused-buffer test only validates forward outputs, not gradient correctness

test_make_graphed_callables_with_interleaved_pipeline_parallelism_reused_buffers compares output_snapshots (forward tensors cloned before the corresponding backward) against the eager baseline. If the clone-on-return logic in Graphed.backward had a bug specifically in the _reuse_graph_input_output_buffers + pipeline path (e.g., gradient accumulation or an incorrect static buffer being read), weights would diverge but the test would still pass. A weight-equality check after one full schedule would strengthen confidence in the gradient path for this mode.
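
A self-contained sketch of the suggested check (helper names are hypothetical, not from the test file):

# Compare final weights between an eager and a graphed run so gradient bugs
# cannot hide behind matching forward outputs.
import torch

def final_weights(model: torch.nn.Module) -> list:
    """Detached copies of all parameters after a full training schedule."""
    return [p.detach().clone() for p in model.parameters()]

def assert_weights_equal(ws_a, ws_b) -> None:
    assert len(ws_a) == len(ws_b)
    for wa, wb in zip(ws_a, ws_b):
        # Bitwise equality: both runs execute the same kernels in the same
        # order, so any drift indicates a gradient-path bug.
        torch.testing.assert_close(wa, wb, rtol=0.0, atol=0.0)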

Contributor Author

Addressed in 4077b85. The interleaved pipeline helper now returns final weights in addition to outputs, and the reused-buffer test compares graph/eager final weights to cover gradient correctness. Full tests/pytorch/test_cuda_graphs.py passed on H100: 415 passed, 423 skipped.

@buptzyb buptzyb force-pushed the fix/cudagraph-wgrad-lifetime branch 2 times, most recently from 441e419 to beff9c1 on April 28, 2026 at 09:40
Signed-off-by: Robin Zhang <robinz@nvidia.com>
@buptzyb buptzyb force-pushed the fix/cudagraph-wgrad-lifetime branch from beff9c1 to 4077b85 on April 28, 2026 at 13:24
@ptrendx
Member

ptrendx commented May 12, 2026

I have a problem with this PR. On the one hand, I agree that there is a bug here and we should ensure that the gradient buffers are not overwritten before being applied. On the other hand, for cases without gradient accumulation (or where accumulation is done differently, as in Megatron), the extra clone is a performance loss. Could we make this behavior optional: on by default, but with an opt-out option and a clear warning stating where disabling it is applicable?

@ptrendx ptrendx self-assigned this May 12, 2026
@buptzyb buptzyb requested a review from ksivaman as a code owner May 13, 2026 04:08
Signed-off-by: Robin Zhang <robinz@nvidia.com>
@buptzyb buptzyb force-pushed the fix/cudagraph-wgrad-lifetime branch from 1a3c2c5 to c8b2ee3 on May 13, 2026 at 04:11
@buptzyb
Contributor Author

buptzyb commented May 13, 2026

@ptrendx I added a new argument _clone_param_grads_on_return (default = True) and documented that disabling it may improve performance but means returned gradients no longer follow standard PyTorch returned-gradient lifetime semantics. Does this work for you?
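
A hedged usage sketch of the resulting API (the flag name is as described in this comment; the module choice and shapes are illustrative):

# TransformerEngine's make_graphed_callables with the PR's opt-out flag.
import torch
import transformer_engine.pytorch as te

model = te.Linear(128, 128).cuda()
sample = (torch.randn(8, 128, device="cuda", requires_grad=True),)

# Default (True): returned parameter grads are cloned, preserving standard
# PyTorch returned-gradient lifetime semantics.
graphed_safe = te.make_graphed_callables(model, sample)

# Opt out only when nothing retains returned grads across replays, e.g.
# Megatron-style external gradient accumulation; this skips the clone cost.
graphed_fast = te.make_graphed_callables(
    model, sample, _clone_param_grads_on_return=False
)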
